import os
import numpy as np
import time
import copy
import subprocess
import tbe
from te.platform import get_soc_spec
from te.platform import te_set_version
import shutil
from concurrent.futures import ProcessPoolExecutor
from tools.acl_execute import acl_op
from tools.atc_om import atc
from tools.constant import acl_dtype
from tools.profiling import op_profiling
from tools.dataloader import save_data, load_data
from tools.acl_process import acl_run


class Executor(object):
    """Helper that compiles a single operator and runs it through ACL.

    The instance carries a handful of public switches that callers may flip
    before invoking :meth:`acl_om` / :meth:`atc_op`:

    * ``dump_graph``      - export ``DUMP_GE_GRAPH=1`` while generating the om.
    * ``is_dynamic``      - set to True once any data info has ``shape_range``.
    * ``save_profiling``  - keep the working directory after a run.
    * ``profiling``       - collect profiling results after a run.
    * ``op_debug_level``  - forwarded to ``atc.op_debug_level`` (cast to int).
    """

    def __init__(self):
        """Initialise default switches and resolve OPP related paths.

        Raises:
            KeyError: if the ``ASCEND_OPP_PATH`` environment variable is unset.
        """
        self.current_path = os.path.abspath(os.path.dirname(__file__))
        # Maps the dtype names used in data infos to numpy scalar types.
        self.dtype2np = {"float16": np.float16, "float32": np.float32,
                         "float": np.float32, "int8": np.int8, "int16": np.int16, "int32": np.int32, "int64": np.int64,
                         "uint8": np.uint8, "uint16": np.uint16}
        # Placeholder returned by acl_om until a profiling run replaces it.
        self.prof_res = [{"task_time": None, "aicore_time": None}]
        self.dump_graph = False
        self.is_dynamic = False
        self.save_profiling = False
        self.profiling = True
        self.op_debug_level = 0
        self.opp_path = os.environ["ASCEND_OPP_PATH"]
        self.op_tiling_path = self.opp_path + "/op_impl/built-in/ai_core/tbe/op_tiling/liboptiling.so"

    def set_version(self, soc_version):
        """Select the SoC version used for compilation.

        Args:
            soc_version: version string; when falsy, the value reported by
                ``get_soc_spec("SOC_VERSION")`` is used instead.
        """
        if not soc_version:
            soc_version = get_soc_spec("SOC_VERSION")
        te_set_version(soc_version)
        print("soc_version:", soc_version)

    def set_device(self, device_id):
        """Export ``DEVICE_ID`` to the environment unless *device_id* is None."""
        if device_id is not None:
            os.environ["DEVICE_ID"] = str(device_id)
            print("device_id:", device_id)

    def set_graph(self):
        """Set or clear the ``DUMP_GE_GRAPH`` environment variable.

        The variable is set to ``'1'`` when ``self.dump_graph`` is truthy and
        removed otherwise.
        """
        if self.dump_graph:
            os.environ["DUMP_GE_GRAPH"] = '1'
        else:
            # ``os.environ`` has no public ``unsetenv`` attribute on current
            # Python versions; deleting the key both updates the mapping and
            # unsets the variable for the process. ``pop`` with a default is
            # a no-op when the variable was never set.
            os.environ.pop("DUMP_GE_GRAPH", None)

    def compile_op(self, op_func, params):
        """Call ``op_func(*params)`` inside a tbe ``OpContext``."""
        with tbe.common.context.op_context.OpContext():
            op_func(*params)

    def check_ndarray(self, data):
        """Raise ``RuntimeError`` unless *data* is a ``numpy.ndarray``."""
        if not isinstance(data, np.ndarray):
            raise RuntimeError("input data should be ndarray")

    def init_dir(self, dir_name):
        """Ensure ``<current_path>/<dir_name>`` exists and return its path.

        Args:
            dir_name: directory name relative to this module's directory.

        Returns:
            The joined path (absolute, since ``current_path`` is absolute).
        """
        dir_path = os.path.join(self.current_path, dir_name)
        if not os.path.exists(dir_path):
            # exist_ok tolerates another process creating the directory
            # between the check above and this call; makedirs also creates
            # missing intermediate directories.
            os.makedirs(dir_path, exist_ok=True)

        return dir_path

    def trans_dtype(self, dtype):
        """Translate a dtype name to its numpy type.

        Raises:
            KeyError: if *dtype* is not a key of ``self.dtype2np``.
        """
        return self.dtype2np[dtype]

    def clear_data(self, op_dir):
        """Recursively delete *op_dir* if it exists."""
        if os.path.exists(op_dir):
            shutil.rmtree(op_dir)

    @staticmethod
    def _scalar_attr_type(value):
        """Return the attr type name of a scalar, or None if unsupported."""
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return "bool"
        if isinstance(value, int):
            return "int"
        if isinstance(value, float):
            return "float"
        if isinstance(value, str):
            return "string"
        return None

    def get_attr_type(self, value):
        """Infer the attribute type name for *value*.

        Scalars map to ``"int"``, ``"float"``, ``"string"`` or ``"bool"``.
        A list/tuple is classified by its first element as ``"list_int"``,
        ``"list_float"``, ``"list_bool"``, ``"list_string"`` or, when the
        first element is itself a sequence starting with an int,
        ``"list_list_int"``.

        Returns:
            The type name, or ``None`` when the type cannot be determined
            (unsupported type, empty sequence, or empty nested sequence).
        """
        if not isinstance(value, (list, tuple)):
            return self._scalar_attr_type(value)

        if not value:
            # Nothing to inspect; an empty sequence has no inferable type.
            return None

        first = value[0]
        if isinstance(first, (list, tuple)):
            if first and isinstance(first[0], int):
                return "list_list_int"
            return None

        scalar_type = self._scalar_attr_type(first)
        if scalar_type is None:
            return None
        return "list_" + scalar_type

    def check_info(self, info):
        """Validate one data-info dict.

        Raises:
            RuntimeError: if ``"shape"`` or ``"dtype"`` is missing, or if the
                dtype (with ``"float32"`` normalised to ``"float"``) is not a
                member of ``acl_dtype``.
        """
        if "shape" not in info:
            raise RuntimeError("data info must includes shape")
        if "dtype" not in info:
            raise RuntimeError("data info must includes dtype")

        dtype = info["dtype"]
        if dtype == "float32":
            dtype = "float"
        if dtype not in acl_dtype:
            raise RuntimeError("dtype may be wrong, please check")

    def process_data_info(self, data_infos):
        """Return validated copies of *data_infos* with ``dtype`` -> ``type``.

        Each info is deep-copied, checked with :meth:`check_info`, and its
        ``"dtype"`` key is replaced by ``"type"`` (``"float"`` becomes
        ``"float32"``). The caller's dicts are not modified.

        Side effect: ``self.is_dynamic`` is set to True when any info carries
        a ``"shape_range"`` key.
        NOTE(review): the flag is never reset here, so it stays True for
        later calls on the same instance - confirm this is intended.
        """
        processed = []
        for info in copy.deepcopy(data_infos):
            self.check_info(info)
            dtype = info.pop("dtype")
            info["type"] = "float32" if dtype == "float" else dtype

            if "shape_range" in info:
                self.is_dynamic = True
            processed.append(info)

        return processed

    def get_data_info(self, params):
        """Return ``(params["input_desc"], params["output_desc"])``.

        Missing keys yield ``None`` for the corresponding element.
        """
        return params.get("input_desc"), params.get("output_desc")

    def get_output_path(self, dir_path, num):
        """Return the paths ``<dir_path>/output_<i>.npy`` for i in 0..num-1."""
        return [os.path.join(dir_path, "output_%s.npy" % index) for index in range(num)]

    def start_profiling(self, op_dir, device_id):
        """Collect and print profiling results when ``self.profiling`` is set.

        Profiling is best effort: any exception is printed and otherwise
        ignored so that a profiling failure does not discard the run result.
        """
        if not self.profiling:
            return
        try:
            self.prof_res = op_profiling.op_profiling(op_dir, device_id)
            if self.prof_res:
                op_profiling.format_print(self.prof_res)
        except Exception as e:
            print(e)

    def acl_kernel(self, op_func, op_params, inputs_data, outputs_info, kernel_name, soc_version=None, device_id=0):
        """Build a kernel with *op_func* and execute it in a worker process.

        Args:
            op_func: callable invoked as ``op_func(*op_params)``.
            op_params: positional arguments for *op_func*; the i-th entry is
                also queried with ``.get("format", "ND")`` for the i-th input.
            inputs_data: input arrays; ``dtype`` and ``shape`` are read from
                each one.
            outputs_info: output data infos, see :meth:`process_data_info`.
            kernel_name: forwarded to ``acl_op.run``.
            soc_version: see :meth:`set_version`.
            device_id: see :meth:`set_device`.

        Returns:
            Whatever ``acl_op.run`` returns.
        """
        self.set_version(soc_version)
        self.set_device(device_id)

        # NOTE(review): this overrides ASCEND_OPP_PATH with the literal
        # "home" for the remainder of the process; the intent is not visible
        # from this file - confirm before changing.
        os.environ["ASCEND_OPP_PATH"] = "home"

        op_func(*op_params)
        inputs_info = []
        for index, data in enumerate(inputs_data):
            inputs_info.append({
                "type": str(data.dtype),
                "shape": data.shape,
                "format": op_params[index].get("format", "ND"),
            })

        outputs_info = self.process_data_info(outputs_info)

        with ProcessPoolExecutor() as pool_executor:
            future = pool_executor.submit(acl_op.run, kernel_name, inputs_data, inputs_info, outputs_info, None,
                                          device_id)
            return future.result()

    def acl_om(self, op_type, inputs_info, inputs_data, outputs_info, attr_dict, real_out=None, soc_version="Ascend910",
               device_id=0):
        """Generate an om for *op_type*, run it, and load the outputs.

        The working directory ``<current_path>/<op_type>`` is wiped and
        recreated first. After the run it is removed again unless
        ``dump_graph``, ``op_debug_level`` or ``save_profiling`` is set.

        Args:
            op_type: operator type; also used as working directory name.
            inputs_info: input data infos (dicts with shape/dtype).
            inputs_data: input data handed to ``save_data``.
            outputs_info: output data infos.
            attr_dict: mapping of attribute name to value.
            real_out: optional data infos that replace the output
                descriptions used for execution.
            soc_version: see :meth:`set_version`.
            device_id: see :meth:`set_device`.

        Returns:
            ``(result, self.prof_res)`` on success. If running or loading
            fails, the exception is printed and ``None`` is returned.

        Raises:
            RuntimeError: if om generation returns a non-zero code.
        """
        self.set_version(soc_version)
        self.set_device(device_id)

        dir_path = os.path.join(self.current_path, op_type)
        self.clear_data(dir_path)
        op_dir = self.init_dir(op_type)

        params, return_code = self.atc_op(op_type, inputs_info, outputs_info, attr_dict, soc_version=soc_version)
        inputs_path = save_data(op_dir, inputs_data)
        if return_code != 0:
            raise RuntimeError("atc om failed")

        attr = params.get("attr")
        inputs_info, outputs_info = self.get_data_info(params)
        if real_out:
            outputs_info = self.process_data_info(real_out)

        try:
            acl_run(op_type, inputs_path, inputs_info, outputs_info, attr, device_id, op_dir, int(self.is_dynamic))
            res_path = self.get_output_path(op_dir, len(outputs_info))
            result = load_data(res_path)

            self.start_profiling(op_dir, device_id)
            # Cast op_debug_level the same way atc_op does so a string level
            # cannot break the decision with a TypeError.
            keep_artifacts = self.dump_graph or int(self.op_debug_level) or self.save_profiling
            if not keep_artifacts:
                self.clear_data(op_dir)

            return result, self.prof_res
        except Exception as e:
            # Best effort: report the failure and signal it with None.
            print(e)
            return None

    def atc_op(self, op_type, inputs_info, outputs_info, attr_dict, soc_version="Ascend910"):
        """Build the single-op description and generate its om with atc.

        Args:
            op_type: operator type; also used as working directory name.
            inputs_info: input data infos, see :meth:`process_data_info`.
            outputs_info: output data infos, see :meth:`process_data_info`.
            attr_dict: mapping of attribute name to value; ``None`` is
                treated as no attributes.
            soc_version: forwarded to ``atc.generate_om``.

        Returns:
            ``(params, return_code)`` where *params* is the description dict
            with keys ``"op"``, ``"input_desc"``, ``"output_desc"`` and
            ``"attr"``, and *return_code* is the value returned by
            ``atc.generate_om``.
        """
        self.set_graph()
        op_dir = self.init_dir(op_type)

        params = {
            "op": op_type,
            "input_desc": self.process_data_info(inputs_info),
            "output_desc": self.process_data_info(outputs_info),
        }
        params["attr"] = [
            {"name": name, "value": value, "type": self.get_attr_type(value)}
            for name, value in (attr_dict or {}).items()
        ]

        atc.op_debug_level = int(self.op_debug_level)
        os.chdir(op_dir)
        try:
            return_code = atc.generate_om(params, op_dir, soc_version)
        finally:
            # Always leave the op directory again, even when generation
            # raises, so later relative paths are not resolved inside it.
            os.chdir(self.current_path)
        if return_code != 0:
            print("atc om failed")

        return params, return_code

# Module-level shared instance, created at import time.
# NOTE: Executor.__init__ reads os.environ["ASCEND_OPP_PATH"], so importing
# this module raises KeyError when that variable is not set.
executor = Executor()
